오늘 풀어본 문제는 백준의 1761번 문제1이다. 문제 풀이에 사용한 언어는 C++ 이다.
이 문제의 내용과 조건은 다음과 같다.
$N$ $(2 \le N \le 40,000)$ 개의 정점으로 이루어진 트리가 주어지고 $M$ $(1 \le M \le 10,000)$ 개의 두 노드 쌍을 입력받을 때 두 노드 사이의 거리를 출력하라.
첫째 줄에 노드의 개수 $N$ 이 입력되고 다음 $N-1$ 개의 줄에 트리 상에 연결된 두 점과 거리를 입력받는다. 그 다음 줄에 $M$ 이 주어지고, 다음 $M$ 개의 줄에 거리를 알고 싶은 노드 쌍이 한 줄에 한 쌍씩 입력된다. 두 점 사이의 거리는 10,000보다 작거나 같은 자연수이다.
정점은 $1$ 번부터 $N$ 번까지 번호가 매겨져 있다.
$M$ 개의 줄에 차례대로 입력받은 두 노드 사이의 거리를 출력한다.
두 정점의 거리를 구하기 위해서는, 두 정점 사이의 경로를 알아야 하는데 이 과정에서 ‘최소 공통 조상’을 반드시 거칠 수 밖에 없다는 걸 생각해냈고, 기존에 풀었던 백준 11438번 문제에서 짜뒀던 Sparse Table을 이용한 LCA를 찾는 코드를 활용하기로 했다.
이 문제에서는 단순히 LCA를 구하는 것이 아닌, 두 정점간의 거리를 구해야 하기 때문에, 각 노드에서 LCA까지의 거리를 알아내어 이를 더하는 방식으로 해결할 수 있을 것이라고 생각하였다. 그래서 LCA를 추적하는 함수에서 LCA까지의 거리도 함께 추적하도록 하였다.
코드는 다음과 같이 작성하였다.
#include <bits/extc++.h>
using namespace std;
using ll = long long int;
using ull = unsigned long long int;
using pll = pair<ll, ll>;
vector<vector<pll>> tree(40001);
vector<int> depth(40001, 0);
vector<vector<pll>> parent(40001, vector<pll>(17));
vector<bool> visited(40001, false);
void setImmediateParent(int currentNode, int currentDepth);
void setParentSparseTable(int N);
int findDistance(int nodeA, int nodeB);
int main(void) {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int N, M, nodeA, nodeB, weight;
cin >> N;
for (int i=0; i<N-1; i++) {
cin >> nodeA >> nodeB >> weight;
tree[nodeA].emplace_back(nodeB, weight);
tree[nodeB].emplace_back(nodeA, weight);
}
setImmediateParent(1, 0);
setParentSparseTable(N);
cin >> M;
for (int i=0; i<M; i++) {
cin >> nodeA >> nodeB;
cout << findDistance(nodeA, nodeB) << '\n';
}
return 0;
}
void setImmediateParent(int curr, int currentDepth) {
visited[curr] = true;
depth[curr] = currentDepth;
for (auto [next, dist] : tree[curr]) {
if (visited[next] == false) {
parent[next][0].first = curr;
parent[next][0].second = dist;
setImmediateParent(next, currentDepth + 1);
}
}
}
void setParentSparseTable(int N) {
int maximumHeight = 0;
for (int tempN=N; tempN>1; maximumHeight++) {
tempN >>= 1;
}
for (int i=1; i<=maximumHeight; i++) {
for (int j=1; j<=N; j++) {
parent[j][i].first = parent[parent[j][i-1].first][i-1].first;
parent[j][i].second = parent[j][i-1].second + parent[parent[j][i-1].first][i-1].second;
}
}
}
int findDistance(int nodeA, int nodeB) {
if (depth[nodeA] < depth[nodeB]) {
swap(nodeA, nodeB);
}
int distFromNodeA = 0;
int distFromNodeB = 0;
int difference = depth[nodeA] - depth[nodeB];
int power = 0;
while (difference > 0) {
if (difference % 2 == 1) {
distFromNodeA += parent[nodeA][power].second;
nodeA = parent[nodeA][power].first;
}
difference >>= 1;
power++;
}
if (nodeA == nodeB) {
return distFromNodeA;
}
else {
for (int i=16; i>=0; i--) {
if (parent[nodeA][i].first != 0 && parent[nodeA][i].first != parent[nodeB][i].first) {
distFromNodeA += parent[nodeA][i].second;
distFromNodeB += parent[nodeB][i].second;
nodeA = parent[nodeA][i].first;
nodeB = parent[nodeB][i].first;
}
else {
for (int j=0; j<=i; j++) {
if (parent[nodeA][j].first == parent[nodeB][j].first) {
distFromNodeA += parent[nodeA][j].second;
distFromNodeB += parent[nodeB][j].second;
break;
}
}
break;
}
}
return distFromNodeA + distFromNodeB;
}
}
실행 결과 ‘틀렸습니다’ 가 떴다.
문제가 되었던 부분은, 마지막에 최소 공통 조상을 추적하는 과정에서 ‘최소’가 아닌 공통 조상까지의 거리를 구하게 될 수 있다는 점이었다. 이는 두 노드의 조상이 같을 때, 최소 공통 조상을 찾는 코드를 잘못 작성했기 때문이었다.
기존의 코드는 $2^i$ 차 부모가 같으면 $j$ 를 $0$ 부터 $i$ 까지 순회하며 $2^j$ 차 부모가 같아지는 최소의 $j$ 로 LCA를 찾았는데, 각 $j$ 는 $2^j$ 차 부모 밖에 표현하지 못하기 때문에 ($2$의 거듭제곱이 아닌 수) 차 부모는 찾아낼 수 없던 것이다. 그래서 $2^j$ 차 부모가 같아질 때, $2^{j-1}$ 차 부모까지의 거리를 먼저 계산하고, 그 부모들의 LCA를 다시 추적해나가는 방식으로 해결을 시도해보았다.
코드는 다음과 같이 수정하였다.
#include <bits/extc++.h>
using namespace std;
using ll = long long int;
using ull = unsigned long long int;
using pll = pair<ll, ll>;
vector<vector<pll>> tree(40001);
vector<int> depth(40001, 0);
vector<vector<pll>> parent(40001, vector<pll>(17));
vector<bool> visited(40001, false);
void setImmediateParent(int currentNode, int currentDepth);
void setParentSparseTable(int N);
int findDistance(int nodeA, int nodeB);
int main(void) {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int N, M, nodeA, nodeB, weight;
cin >> N;
for (int i=0; i<N-1; i++) {
cin >> nodeA >> nodeB >> weight;
tree[nodeA].emplace_back(nodeB, weight);
tree[nodeB].emplace_back(nodeA, weight);
}
setImmediateParent(1, 0);
setParentSparseTable(N);
cin >> M;
for (int i=0; i<M; i++) {
cin >> nodeA >> nodeB;
cout << findDistance(nodeA, nodeB) << '\n';
}
return 0;
}
void setImmediateParent(int curr, int currentDepth) {
visited[curr] = true;
depth[curr] = currentDepth;
for (auto [next, dist] : tree[curr]) {
if (visited[next] == false) {
parent[next][0].first = curr;
parent[next][0].second = dist;
setImmediateParent(next, currentDepth + 1);
}
}
}
void setParentSparseTable(int N) {
int maximumHeight = 0;
for (int tempN=N; tempN>1; maximumHeight++) {
tempN >>= 1;
}
for (int i=1; i<=maximumHeight; i++) {
for (int j=1; j<=N; j++) {
parent[j][i].first = parent[parent[j][i-1].first][i-1].first;
parent[j][i].second = parent[j][i-1].second + parent[parent[j][i-1].first][i-1].second;
}
}
}
int findDistance(int nodeA, int nodeB) {
if (depth[nodeA] < depth[nodeB]) {
swap(nodeA, nodeB);
}
int distFromNodeA = 0;
int distFromNodeB = 0;
int difference = depth[nodeA] - depth[nodeB];
int power = 0;
while (difference > 0) {
if (difference % 2 == 1) {
distFromNodeA += parent[nodeA][power].second;
nodeA = parent[nodeA][power].first;
}
difference >>= 1;
power++;
}
if (nodeA == nodeB) {
return distFromNodeA;
}
else {
for (int i=16; i>=0; i--) {
if (parent[nodeA][i].first != 0) {
if (parent[nodeA][i].first != parent[nodeB][i].first) {
distFromNodeA += parent[nodeA][i].second;
distFromNodeB += parent[nodeB][i].second;
nodeA = parent[nodeA][i].first;
nodeB = parent[nodeB][i].first;
}
else {
for (int j=0; j<=i; j++) {
if (parent[nodeA][j].first == parent[nodeB][j].first) {
if (j <= 1) {
distFromNodeA += parent[nodeA][j].second;
distFromNodeB += parent[nodeB][j].second;
break;
}
else {
distFromNodeA += parent[nodeA][j-1].second;
distFromNodeB += parent[nodeB][j-1].second;
nodeA = parent[nodeA][j-1].first;
nodeB = parent[nodeB][j-1].first;
j = -1;
}
}
}
break;
}
}
}
return distFromNodeA + distFromNodeB;
}
}
그러자 모든 테스트 케이스를 통과하고 ‘맞았습니다’가 나오는 것을 확인할 수 있었다.
LCA 알고리즘은 상당히 오랜만에 쓰는 알고리즘이라 제대로 활용할 수 있을지 걱정했지만, 생각보다 간단하게 해결된 것 같다. CLASS 6 문제들을 풀다 보면 종종 보게될 것 같은데, 그 문제들도 쉽게 풀렸으면 좋겠다.
오늘의 PS는 여기까지!
1: https://www.acmicpc.net/problem/1761